{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n实验目的：实现基于RNN的文本分类\\n\\n实验内容：\\n1）词嵌入初始化方式：随机embedding、加载glove\\n2）CNN/RNN的特征抽取\\n3）Dropout\\n\\n\\n参考：\\nhttps://arxiv.org/abs/1408.5882\\nhttps://github.com/yokusama/CNN_Sentence_Classification\\nhttps://torchtext.readthedocs.io/en/latest/\\nhttp://mlexplained.com/2018/02/08/a-comprehensive-tutorial-to-torchtext/\\nhttps://machinelearningmastery.com/sequence-classification-lstm-recurrent-neural-networks-python-keras/\\nhttps://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py#L39-L58\\n\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "实验目的：实现基于RNN的文本分类\n",
    "\n",
    "实验内容：\n",
    "1）词嵌入初始化方式：随机embedding、加载glove\n",
    "2）CNN/RNN的特征抽取\n",
    "3）Dropout\n",
    "\n",
    "\n",
    "参考：\n",
    "https://arxiv.org/abs/1408.5882\n",
    "https://github.com/yokusama/CNN_Sentence_Classification\n",
    "https://torchtext.readthedocs.io/en/latest/\n",
    "http://mlexplained.com/2018/02/08/a-comprehensive-tutorial-to-torchtext/\n",
    "https://machinelearningmastery.com/sequence-classification-lstm-recurrent-neural-networks-python-keras/\n",
    "https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py#L39-L58\n",
    "\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D:\\workspace\\nlp-beginner_solution\\Task2\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "print(os.getcwd())\n",
    "\n",
    "dir_all_data='data\\\\task2_all_data.tsv'\n",
    "\n",
    "BATCH_SIZE=10\n",
    "\n",
    "cpu=True   #True   False \n",
    "if cpu :\n",
    "    USE_CUDA = False\n",
    "    DEVICE = torch.device('cpu')\n",
    "else:\n",
    "    USE_CUDA = torch.cuda.is_available()\n",
    "    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    torch.cuda.set_device(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#从文件中读取数据\n",
    "data_all=pd.read_csv(dir_all_data,sep='\\t')\n",
    "#print(all_data.shape)    #(156060, 4)\n",
    "#print(all_data.keys())   #['PhraseId', 'SentenceId', 'Phrase', 'Sentiment']\n",
    "idx =np.arange(data_all.shape[0])\n",
    "#print(data_all.head())\n",
    "#print(type(idx))   #<class 'numpy.ndarray'>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#shuffle，划分验证集、测试集,并保存\n",
    "seed=0\n",
    "np.random.seed(seed)\n",
    "#print(idx)\n",
    "np.random.shuffle(idx)  \n",
    "#print(idx)\n",
    "\n",
    "train_size=int(len(idx) * 0.6)\n",
    "test_size =int(len(idx) * 0.8)\n",
    "\n",
    "data_all.iloc[idx[:train_size], :].to_csv('data/task2_train.csv',index=False)\n",
    "data_all.iloc[idx[train_size:test_size], :].to_csv(\"data/task2_test.csv\", index=False)\n",
    "data_all.iloc[idx[test_size:], :].to_csv(\"data/task2_dev.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Torchtext采用声明式方法加载数据\n",
    "from torchtext import data\n",
    "PAD_TOKEN='<pad>'\n",
    "TEXT = data.Field(sequential=True,batch_first=True, lower=True, pad_token=PAD_TOKEN)\n",
    "LABEL = data.Field(sequential=False, batch_first=True, unk_token=None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#读取数据\n",
    "\n",
    "datafields = [(\"PhraseId\", None), # 不需要的filed设置为None\n",
    "              (\"SentenceId\", None),\n",
    "              ('Phrase', TEXT),\n",
    "              ('Sentiment', LABEL)]\n",
    "train_data = data.TabularDataset(path='data/task2_train2.csv', format='csv',\n",
    "                                fields=datafields)\n",
    "dev_data  = data.TabularDataset(path='data/task2_dev.csv', format='csv',\n",
    "                                fields=datafields)\n",
    "test_data = data.TabularDataset(path='data/task2_test2.csv', format='csv',\n",
    "                                fields=datafields)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#构建词典，字符映射到embedding\n",
    "#TEXT.vocab.vectors 就是词向量\n",
    "TEXT.build_vocab(train_data,  vectors= 'glove.6B.50d',   #可以提前下载好\n",
    "                 unk_init= lambda x:torch.nn.init.uniform_(x, a=-0.25, b=0.25))\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "#得到索引，PAD_TOKEN='<pad>'\n",
    "PAD_INDEX = TEXT.vocab.stoi[PAD_TOKEN]\n",
    "TEXT.vocab.vectors[PAD_INDEX] = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#构建迭代器\n",
    "train_iterator = data.BucketIterator(train_data, batch_size=BATCH_SIZE, \n",
    "                                     train=True, shuffle=True,device=DEVICE)\n",
    "\n",
    "dev_iterator = data.Iterator(dev_data, batch_size=len(dev_data), train=False,\n",
    "                         sort=False, device=DEVICE)\n",
    "\n",
    "test_iterator = data.Iterator(test_data, batch_size=len(test_data), train=False,\n",
    "                          sort=False, device=DEVICE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16525 6\n"
     ]
    }
   ],
   "source": [
    "#部分参数设置\n",
    "embedding_choice='glove'   #  'static'    'non-static'\n",
    "num_embeddings = len(TEXT.vocab)\n",
    "embedding_dim =50\n",
    "dropout_p=0.5\n",
    "hidden_size=50  #隐藏单元数\n",
    "num_layers=2  #层数\n",
    "\n",
    "vocab_size=len(TEXT.vocab)\n",
    "label_num=len(LABEL.vocab)\n",
    "print(vocab_size,label_num)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "class LSTM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LSTM,self).__init__()\n",
    "        \n",
    "        self.embedding_choice=embedding_choice        \n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        \n",
    "        if self.embedding_choice==  'rand':\n",
    "            self.embedding=nn.Embedding(num_embeddings,embedding_dim)\n",
    "        if self.embedding_choice==  'glove':\n",
    "            self.embedding = nn.Embedding(num_embeddings, embedding_dim, \n",
    "                padding_idx=PAD_INDEX).from_pretrained(TEXT.vocab.vectors, freeze=True)\n",
    "        #input_size (输入的特征维度),hidden_size ,num_layers \n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers,\n",
    "                            batch_first=True,dropout=dropout_p,bidirectional=True)\n",
    "        self.dropout = nn.Dropout(dropout_p)    \n",
    "        self.fc = nn.Linear(hidden_size * 2, label_num)  # 2 for bidirection\n",
    "        \n",
    "        \n",
    "        \n",
    "    def forward(self,x):      # (Batch_size, Length) \n",
    "        # Set initial hidden and cell states \n",
    "        # h_n (num_layers * num_directions, batch, hidden_size)\n",
    "        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size)\n",
    "        # c_n (num_layers * num_directions, batch, hidden_size): \n",
    "        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size)\n",
    "        \n",
    "        if USE_CUDA:\n",
    "            h0=h0.cuda()\n",
    "            c0=c0.cuda()\n",
    "        \n",
    "        x=self.embedding(x)     #(Batch_size, Length) \n",
    "                                       #(Batch_size,  Length, Dimention) \n",
    "        \n",
    "        out, _ = self.lstm(x, (h0, c0))   #(Batch_size, Length，Dimention) \n",
    "                                        # (batch_size, Length, hidden_size)  \n",
    "        out=self.dropout(out)\n",
    "        \n",
    "        out = self.fc(out[:, -1, :])   # (batch_size, Length, hidden_size)  \n",
    "                           # (batch_size, label_num)  \n",
    "        return out "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#构建模型\n",
    "\n",
    "model=LSTM()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)#创建优化器SGD\n",
    "criterion = nn.CrossEntropyLoss()   #损失函数\n",
    "\n",
    "if USE_CUDA:\n",
    "    model.cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-14-b975ddac5cc2>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     19\u001b[0m         \u001b[0mbatch_text\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mPhrase\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m         \u001b[0mbatch_label\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSentiment\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m         \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch_text\u001b[0m\u001b[1;33m)\u001b[0m    \u001b[1;31m#[batch_size, label_num]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     22\u001b[0m         \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_label\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     23\u001b[0m         \u001b[0mtotal_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtotal_loss\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\.conda\\envs\\python36\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    492\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    494\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-10-49499159bfde>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     36\u001b[0m                                        \u001b[1;31m#(Batch_size,  Length, Dimention)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     37\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 38\u001b[1;33m         \u001b[0mout\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlstm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mh0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mc0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m   \u001b[1;31m#(Batch_size, Length，Dimention)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     39\u001b[0m                                         \u001b[1;31m# (batch_size, Length, hidden_size)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     40\u001b[0m         \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\.conda\\envs\\python36\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    492\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m             \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    494\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\.conda\\envs\\python36\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input, hx)\u001b[0m\n\u001b[0;32m    557\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_packed\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    558\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 559\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    560\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    561\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\.conda\\envs\\python36\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mforward_tensor\u001b[1;34m(self, input, hx)\u001b[0m\n\u001b[0;32m    537\u001b[0m         \u001b[0munsorted_indices\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    538\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 539\u001b[1;33m         \u001b[0moutput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmax_batch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    540\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    541\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munsorted_indices\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\.conda\\envs\\python36\\lib\\site-packages\\torch\\nn\\modules\\rnn.py\u001b[0m in \u001b[0;36mforward_impl\u001b[1;34m(self, input, hx, batch_sizes, max_batch_size, sorted_indices)\u001b[0m\n\u001b[0;32m    520\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    521\u001b[0m             result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,\n\u001b[1;32m--> 522\u001b[1;33m                               self.dropout, self.training, self.bidirectional, self.batch_first)\n\u001b[0m\u001b[0;32m    523\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    524\u001b[0m             result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import time\n",
    "epoch=100\n",
    "best_accuracy=0.0\n",
    "start_time=time.time()\n",
    "\n",
    "for i in range(epoch):\n",
    "    model.train()\n",
    "    total_loss=0.0\n",
    "    accuracy=0.0\n",
    "    total_correct=0.0\n",
    "    total_data_num = len(train_iterator.dataset)\n",
    "    steps = 0.0\n",
    "    #训练\n",
    "    for batch in train_iterator:\n",
    "        steps+=1\n",
    "        #print(steps)\n",
    "        optimizer.zero_grad() #  梯度缓存清零\n",
    "        \n",
    "        batch_text=batch.Phrase\n",
    "        batch_label=batch.Sentiment\n",
    "        out=model(batch_text)    #[batch_size, label_num]\n",
    "        loss = criterion(out, batch_label)\n",
    "        total_loss = total_loss + loss.item() \n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()        \n",
    "\n",
    "        correct = (torch.max(out, dim=1)[1]  #get the indices\n",
    "                   .view(batch_label.size()) == batch_label).sum()\n",
    "        total_correct = total_correct + correct.item()\n",
    "\n",
    "        if steps%100==0:\n",
    "            print(\"Epoch %d_%.3f%%:  Training average Loss: %f\"\n",
    "                      %(i, steps * train_iterator.batch_size*100/len(train_iterator.dataset),total_loss/steps))  \n",
    "\n",
    "    #验证\n",
    "    model.eval()\n",
    "    total_loss=0.0\n",
    "    accuracy=0.0\n",
    "    total_correct=0.0\n",
    "    total_data_num = len(dev_iterator.dataset)\n",
    "    steps = 0.0    \n",
    "    for batch in dev_iterator:\n",
    "        steps+=1\n",
    "        batch_text=batch.Phrase\n",
    "        batch_label=batch.Sentiment\n",
    "        out=model(batch_text)\n",
    "        loss = criterion(out, batch_label)\n",
    "        total_loss = total_loss + loss.item()\n",
    "        \n",
    "        correct = (torch.max(out, dim=1)[1].view(batch_label.size()) == batch_label).sum()\n",
    "        total_correct = total_correct + correct.item()\n",
    "        \n",
    "        print(\"Epoch %d :  Verification average Loss: %f, Verification accuracy: %f%%,Total Time:%f\"\n",
    "          %(i, total_loss/steps, total_correct*100/total_data_num,time.time()-start_time))  \n",
    "        \n",
    "        if best_accuracy < total_correct/total_data_num :\n",
    "            best_accuracy =total_correct/total_data_num \n",
    "            torch.save(model,'model_dict/model_lstm/epoch_%d_accuracy_%f'%(i,total_correct/total_data_num))\n",
    "            print('Model is saved in model_dict/model_lstm/epoch_%d_accuracy_%f'%(i,total_correct/total_data_num))\n",
    "            #torch.cuda.empty_cache()\n",
    "    break  #训练时去掉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#测试-重新读取文件（方便重写成.py文件）\n",
    "PATH='model_dict/model_lstm/epoch_5_a'\n",
    "model = torch.load(PATH)\n",
    "\n",
    "total_loss=0.0\n",
    "accuracy=0.0\n",
    "total_correct=0.0\n",
    "total_data_num = len(train_iterator.dataset)\n",
    "steps = 0.0    \n",
    "start_time=time.time()\n",
    "for batch in test_iterator:\n",
    "    steps+=1\n",
    "    batch_text=batch.Phrase\n",
    "    batch_label=batch.Sentiment\n",
    "    out=model(batch_text)\n",
    "    loss = criterion(out, batch_label)\n",
    "    total_loss = total_loss + loss.item()\n",
    "\n",
    "    correct = (torch.max(out, dim=1)[1].view(batch_label.size()) == batch_label).sum()\n",
    "    total_correct = total_correct + correct.item()\n",
    "\n",
    "print(\"Test average Loss: %f, Test accuracy: %f，Total time: %f\"\n",
    "  %(total_loss/steps, total_correct/total_data_num,time.time()-start_time) ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python36",
   "language": "python",
   "name": "python36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
